from typing import List, Union, Tuple

from Voicelab.pipeline.Node import Node
from parselmouth.praat import call
from Voicelab.toolkits.Voicelab.VoicelabNode import VoicelabNode

###################################################################################################
# MANIPULATE PITCH AND FORMANTS NODE
# WARIO pipeline node for manipulating the pitch and/or formants of a voice.
###################################################################################################
# ARGUMENTS
# 'voice'   : parselmouth Sound object to manipulate
###################################################################################################
# RETURNS
###################################################################################################


class ManipulatePitchAndFormantsNode(VoicelabNode):
    """Create pitch- and/or formant-shifted copies of a voice recording.

    For every enabled option ("Lower Pitch", "Raise Pitch", "Lower Formants",
    "Raise Formants", "Lower Pitch and Formants", "Raise Pitch and Formants")
    one manipulated copy of ``self.args["voice"]`` is resynthesised with
    Praat's overlap-add method, written to a WAV file, and returned.
    """

    # Option name -> (pitch direction, formant direction):
    # +1 raises, -1 lowers, 0 leaves that dimension untouched.
    _VARIANTS = (
        ("Lower Pitch", -1, 0),
        ("Raise Pitch", 1, 0),
        ("Lower Formants", 0, -1),
        ("Raise Formants", 0, 1),
        ("Lower Pitch and Formants", -1, -1),
        ("Raise Pitch and Formants", 1, 1),
    )

    _PITCH_METHODS = ("Shift frequencies", "Multiply frequencies")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.args = {
            # Unit for "Shift frequencies"; these are the units Praat's
            # PitchTier "Shift frequencies" command accepts.
            "unit": ("ERB", ["ERB", "Hertz", "mel", "logHertz", "semitones"]),
            # The default must be one of the listed options (it used to be "Shift").
            "Pitch Method": ("Shift frequencies", ["Shift frequencies", "Multiply frequencies"]),
            # Shift: amount in "unit". Multiply: fraction (values > 1 are read as percent).
            "Pitch Amount": 0.5,
            "time_step": 0.001,
            # Percentage by which formants are lowered/raised, e.g. 15 -> x0.85 / x1.15.
            "Formant scalar (%)": 15,
            "Lower Pitch": True,
            "Raise Pitch": True,
            "Lower Formants": True,
            "Raise Formants": True,
            "Lower Pitch and Formants": True,
            "Raise Pitch and Formants": True,
            # Scale every output to the same average intensity (70 dB).
            "normalize amplitude": True,
        }

    @staticmethod
    def _choice(value):
        """Return the selected option of a choice argument.

        Choice arguments are declared as ``(selected, options)`` tuples; any
        other value is returned unchanged.
        NOTE(review): assumes the pipeline keeps the tuple form — confirm.
        """
        return value[0] if isinstance(value, tuple) else value

    ###############################################################################################
    # process: WARIO hook called once for each voice file.
    ###############################################################################################

    def process(self):
        """Build, save and return every enabled variant of the current voice.

        Each variant is written to ``<voice name>_<option>.wav`` in the current
        working directory.

        Returns:
            dict mapping each enabled option name (e.g. ``"Lower Pitch"``) to
            its manipulated parselmouth Sound.
            NOTE(review): confirm this result format against the WARIO pipeline.
        """
        sound = self.args["voice"]
        base_name = sound.name or "voice"
        results = {}
        for option, pitch_method, pitch_value, formant_factor in self.initialize_values():
            manipulated_sound = self.manipulate_pitch_and_formants(
                pitch_method, pitch_value, formant_factor
            )
            suffix = option.lower().replace(" ", "_")
            manipulated_sound.save(f"{base_name}_{suffix}.wav", "WAV")
            results[option] = manipulated_sound
        return results

    def initialize_values(self):
        """Translate the arguments into one manipulation recipe per enabled option.

        ``self.args`` is not modified, so repeated calls give identical results.

        Returns:
            list of ``(option, pitch_method, pitch_value, formant_factor)`` tuples.
            ``pitch_method`` is ``None`` (pitch untouched), ``"Shift frequencies"``
            (``pitch_value`` is a signed shift in the selected unit) or
            ``"Multiply frequencies"`` (``pitch_value`` is a positive factor).
            ``formant_factor`` multiplies every formant frequency.

        Raises:
            ValueError: for an unknown pitch method, or when the requested
                amounts would produce a non-positive factor.
        """
        method = self._choice(self.args["Pitch Method"])
        if method not in self._PITCH_METHODS:
            raise ValueError(f"Unknown pitch method: {method!r}")

        amount = self.args["Pitch Amount"]
        if method == "Multiply frequencies" and amount > 1:
            # Values above 1 are interpreted as percentages, e.g. 50 -> 0.5.
            amount = amount / 100
        formant_fraction = self.args["Formant scalar (%)"] / 100

        recipes = []
        for option, pitch_direction, formant_direction in self._VARIANTS:
            if not self.args.get(option, False):
                continue

            if pitch_direction == 0:
                pitch_method, pitch_value = None, None
            elif method == "Shift frequencies":
                pitch_method, pitch_value = method, pitch_direction * amount
            else:
                pitch_method, pitch_value = method, 1 + pitch_direction * amount

            formant_factor = 1 + formant_direction * formant_fraction

            bad_pitch = pitch_method == "Multiply frequencies" and pitch_value <= 0
            if bad_pitch or formant_factor <= 0:
                raise ValueError(
                    f"{option}: the requested amounts give a non-positive factor "
                    f"(pitch: {pitch_value}, formants: {formant_factor})"
                )
            recipes.append((option, pitch_method, pitch_value, formant_factor))
        return recipes

    def manipulate_pitch_and_formants(self, pitch_method=None, pitch_value=None, formant_factor=1.0):
        """Resynthesise ``self.args["voice"]`` with modified pitch and/or formants.

        Args:
            pitch_method: ``None`` to keep the pitch, ``"Shift frequencies"`` or
                ``"Multiply frequencies"`` (Praat PitchTier commands).
            pitch_value: shift in ``self.args["unit"]`` or multiplication factor,
                depending on ``pitch_method``.
            formant_factor: factor applied to all formant frequencies (1 = unchanged).

        Returns:
            The manipulated parselmouth Sound, at the original sampling rate and
            (approximately) the original duration.
        """
        sound = self.args["voice"]
        time_step = self.args["time_step"]
        unit = self._choice(self.args["unit"])
        start, end = sound.xmin, sound.xmax
        sampling_rate = sound.sampling_frequency

        # pitch_bounds is inherited (presumably from VoicelabNode) — returns (f0min, f0max).
        f0min, f0max = self.pitch_bounds(sound)
        manipulation = call(sound, "To Manipulation", time_step, f0min, f0max)
        pitch_tier = call(manipulation, "Extract pitch tier")

        if pitch_method == "Shift frequencies":
            call(pitch_tier, "Shift frequencies", start, end, pitch_value, unit)
        elif pitch_method == "Multiply frequencies":
            call(pitch_tier, "Multiply frequencies", start, end, pitch_value)

        if formant_factor != 1:
            # Playing the result back at sampling_rate * formant_factor (below)
            # scales every frequency by formant_factor and every duration by
            # 1 / formant_factor. Pre-compensate pitch and duration here so that
            # only the formants end up changed.
            call(pitch_tier, "Multiply frequencies", start, end, 1 / formant_factor)
            duration_tier = call("Create DurationTier", "DurationTier", start, end)
            call(duration_tier, "Add point", start, formant_factor)
            call([manipulation, duration_tier], "Replace duration tier")

        call([manipulation, pitch_tier], "Replace pitch tier")
        manipulated_sound = call(manipulation, "Get resynthesis (overlap-add)")

        if formant_factor != 1:
            call(manipulated_sound, "Override sampling frequency", sampling_rate * formant_factor)
            # Bring the signal back to the original rate so outputs are comparable.
            manipulated_sound = call(manipulated_sound, "Resample", sampling_rate, 50)

        if self.args.get("normalize amplitude", True):
            manipulated_sound.scale_intensity(70)

        return manipulated_sound
